9025. k-th element

 

An array a of n integers and a number k are given. Find the k-th element of the array a after sorting it (indexing starts from 1).

 

Input. The first line contains two integers n and k (1 ≤ n ≤ 2000, 1 ≤ k ≤ n). The second line contains n integers a_i (1 ≤ i ≤ n, 1 ≤ a_i ≤ 2000).

 

Output. Print the k-th element of the sorted array a.

 

Sample input 1

2 1

2 1

Sample output 1

1

 

 

Sample input 2

5 3

4 7 1 8 12

Sample output 2

7

 

 

SOLUTION

k-th order statistic

 

Algorithm analysis

To solve the problem in O(n log n) time, it is enough to sort the array and print its k-th element.

 

We can use the nth_element function, which, in O(n) time on average, rearranges the elements of the array so that the k-th element ends up in the k-th position, all elements to its left are no greater than a[k], and all elements to its right are no less than a[k].

 

The k-th order statistic can be found in linear time on average using the partition function from the quicksort algorithm. With the first element as the pivot, the worst case is O(n²), for example on an already sorted array. The partition function splits (but does not sort) the array a[1..n] in linear time into two parts a[1..pos] and a[pos + 1..n], so that every element of the first part is no greater than every element of the second part. If k ≤ pos, we look for the k-th order statistic in a[1..pos]; otherwise we look for it in a[pos + 1..n].

 

Algorithm implementation

Declare the array.

 

vector<int> v{};  // holds a[1..n]; v[0] is an unused placeholder

 

Read the input data.

 

scanf("%d %d", &n, &k);

v.resize(n + 1);

for (i = 1; i <= n; i++)

  scanf("%d", &v[i]);

 

Sort the array starting from index 1.

 

sort(begin(v) + 1, end(v));  // sort v[1..n]; v[0] is left untouched

 

Print the k-th element.

 

printf("%d\n", v[k]);

 

Algorithm implementation – nth_element

 

#include <cstdio>

#include <vector>

#include <algorithm>

using namespace std;

 

vector<int> v;  // holds a[1..n]; v[0] is an unused placeholder
int n;          // number of array elements
int k;          // requested 1-based rank
int i;          // loop counter used by main

 

int main(void)

{

  scanf("%d %d", &n, &k);

  v.resize(n + 1);

  for (i = 1; i <= n; i++)

    scanf("%d", &v[i]);

 

  nth_element(v.begin() + 1, v.begin() + k, v.end());

  printf("%d\n", v[k]);

  return 0;

}

 

Algorithm implementation – k-th order statistic

 

#include <cstdio>

#include <vector>

#include <algorithm>

using namespace std;

 

vector<int> v;  // holds a[1..n]; v[0] is an unused placeholder
int n;          // number of array elements
int k;          // requested 1-based rank
int i;          // loop counter used by main

 

int Partition(int left, int right)

{

  int x = v[left], i = left - 1, j = right + 1;

  while (1)

  {

    do j--; while (v[j] > x);

    do i++; while (v[i] < x);

    if (i < j) swap(v[i], v[j]); else return j;

  }

}

 

int kth(int k, int left, int right)

{

  if (left == right) return v[left];

  int pos = Partition(left, right);

  if (k <= pos) return kth(k, left, pos);

  else return kth(k, pos + 1, right);

}

 

int main(void)

{

  scanf("%d %d", &n, &k);

  v.resize(n + 1);

  for (i = 1; i <= n; i++)

    scanf("%d", &v[i]);

 

  printf("%d\n", kth(k, 1, n));

  return 0;

}

 

Java implementation

 

import java.util.*;

 

public class Main
{
  // Exchange a[i] and a[j].
  static void swap(int a[], int i, int j)
  {
    final int saved = a[i];
    a[i] = a[j];
    a[j] = saved;
  }

  /**
   * Hoare partition of a[L..R] (inclusive) around the pivot a[L].
   * When L < R, returns j with L <= j < R such that every element of
   * a[L..j] is <= every element of a[j+1..R]. Reorders a in place.
   */
  static int Partition(int a[], int L, int R)
  {
    final int pivot = a[L];
    int lo = L - 1;
    int hi = R + 1;
    for (;;)
    {
      do { hi--; } while (a[hi] > pivot);  // stop on an element <= pivot
      do { lo++; } while (a[lo] < pivot);  // stop on an element >= pivot
      if (lo >= hi)
        return hi;
      swap(a, lo, hi);
    }
  }

  /**
   * Quickselect: returns the value that would occupy index k if a[left..right]
   * were sorted (k is an absolute index, left <= k <= right). Reorders a in place.
   */
  static int kth(int a[], int k, int left, int right)
  {
    while (left != right)
    {
      final int pos = Partition(a, left, right);
      if (k <= pos)
        right = pos;     // answer lies in a[left..pos]
      else
        left = pos + 1;  // answer lies in a[pos+1..right]
    }
    return a[left];
  }

  public static void main(String[] args)
  {
    Scanner in = new Scanner(System.in);
    final int n = in.nextInt();
    final int k = in.nextInt();
    int[] values = new int[n + 1];  // values[1..n]; values[0] is unused
    for (int idx = 1; idx <= n; idx++)
    {
      values[idx] = in.nextInt();
    }
    System.out.println(kth(values, k, 1, n));
    in.close();
  }
}

 

Python implementation

 

def Partition(lst, left, right):
  """Hoare partition of lst[left..right] (inclusive) around the pivot lst[left].

  Reorders the slice in place. When left < right, returns an index j with
  left <= j < right such that every element of lst[left..j] is <= every
  element of lst[j+1..right].
  """
  x = lst[left]
  i = left - 1
  j = right + 1
  while True:
    # Emulate C's do-while: step first, then keep stepping while out of place.
    j -= 1
    while lst[j] > x:
      j -= 1
    i += 1
    while lst[i] < x:
      i += 1
    if i < j:
      lst[i], lst[j] = lst[j], lst[i]
    else:
      return j


def kth(lst, k, left, right):
  """Return the value that would sit at index k if lst[left..right] were sorted.

  k is an absolute index into lst with left <= k <= right; lst is reordered
  in place. The search is iterative rather than recursive: with the
  first-element pivot, an already sorted range shrinks by only one element
  per Partition call, so recursion depth could reach n (up to 2000) and
  exceed Python's default recursion limit of 1000.
  """
  while left < right:
    pos = Partition(lst, left, right)
    if k <= pos:
      right = pos      # answer lies in lst[left..pos]
    else:
      left = pos + 1   # answer lies in lst[pos+1..right]
  return lst[left]

 

import sys

# Read every whitespace-separated token at once so the n array values are
# accepted even if they are spread over several input lines.
tokens = sys.stdin.read().split()
n, k = int(tokens[0]), int(tokens[1])
lst = [int(t) for t in tokens[2:2 + n]]
# kth works with 0-based indices, so the 1-based rank k becomes k - 1.
res = kth(lst, k - 1, 0, n - 1)
print(res)